{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import mmengine\n",
    "import pickle\n",
    "from sklearn.cluster import KMeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# info_path = '../data/CODA_motion_modes.pkl'\n",
    "# info_path = '../data/CODA_test.pkl'\n",
    "# info_path = '../data/CODA/coda_traj_train.pkl'\n",
    "info_path = '../data/CODA/coda_track_results.pkl'\n",
    "info = mmengine.load(info_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "200"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(info[0]['data_list'][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'track_ids': [0],\n",
       " 'track_bboxes': [array([-109.50052607,  360.62135559,    4.52157606,    0.74414235,\n",
       "            0.71282548,    1.65397727,   -1.74968522,    0.61664706])],\n",
       " 'track_states': ['alive_1_0'],\n",
       " 'track_labels': [1],\n",
       " 'gt_bboxes': [array([-109.73163116,  361.45533251,    4.52523609,    0.99730501,\n",
       "            0.83391031,    1.76094715,   -1.54632966]),\n",
       "  array([-109.50542925,  360.59398011,    4.52043673,    0.99730501,\n",
       "            0.81961932,    1.76094715,   -1.54632966])],\n",
       " 'gt_labels': [1, 1],\n",
       " 'gt_ids': ['Pedestrian:4', 'Pedestrian:5']}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "info[0]['data_list'][1][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'10_5815'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "info[0]['token']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_info_path = '../data/CODA_train.pkl'\n",
    "f = open(train_info_path, 'rb')\n",
    "train_info = pickle.load(f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_len = 8\n",
    "pred_len = 12\n",
    "n_clusters = 50\n",
    "results = []\n",
    "for data in train_info:\n",
    "    traj = np.concatenate([np.array(data[0]), np.array(data[1])], axis=0)\n",
    "    traj = traj - traj[obs_len - 1]\n",
    "    ref = traj[0]\n",
    "    angle = np.arctan2(ref[1], ref[0])\n",
    "    rot_mat = np.array([[np.cos(angle), -np.sin(angle)],\n",
    "                        [np.sin(angle), np.cos(angle)]])\n",
    "    traj = np.dot(traj, rot_mat.T)\n",
    "    results.append(traj)\n",
    "results = np.array(results)[:100000]\n",
    "cluster_data = results[:, obs_len:].reshape(results.shape[0], -1)\n",
    "clf = KMeans(n_clusters=n_clusters).fit(cluster_data)\n",
    "motion_modes = clf.cluster_centers_.reshape(n_clusters, -1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100000"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "pos_range = [np.inf, np.inf, -np.inf, -np.inf]\n",
    "pos_idx = [-1, -1, -1, -1]\n",
    "for i, result in enumerate(results):\n",
    "    mean_pos = np.mean(result[obs_len:], axis=0)\n",
    "    if mean_pos[0] < pos_range[0]:\n",
    "        pos_range[0] = mean_pos[0]\n",
    "        pos_idx[0] = i\n",
    "    if mean_pos[0] > pos_range[2]:\n",
    "        pos_range[2] = mean_pos[0]\n",
    "        pos_idx[2] = i\n",
    "    if mean_pos[1] < pos_range[1]:\n",
    "        pos_range[1] = mean_pos[1]\n",
    "        pos_idx[1] = i\n",
    "    if mean_pos[1] > pos_range[3]:\n",
    "        pos_range[3] = mean_pos[1]\n",
    "        pos_idx[3] = i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[ 575376.06, -509707.1 ],\n",
       "        [ 575376.06, -509707.1 ],\n",
       "        [ 575376.06, -509707.1 ],\n",
       "        [ 575376.06, -509707.1 ],\n",
       "        [ 575376.06, -509707.1 ],\n",
       "        [ 575376.06, -509707.1 ],\n",
       "        [ 571229.44, -510410.44],\n",
       "        [ 567343.44, -510973.84]], dtype=float32),\n",
       " array([[ 559475.3 , -503173.12],\n",
       "        [ 558330.75, -495851.66],\n",
       "        [ 554147.44, -491298.16],\n",
       "        [ 549872.2 , -486682.72],\n",
       "        [ 541826.1 , -483952.75],\n",
       "        [ 533105.3 , -483064.16],\n",
       "        [ 527820.6 , -480244.2 ],\n",
       "        [ 522496.22, -477364.84],\n",
       "        [ 522239.72, -470708.34],\n",
       "        [ 515234.38, -461643.8 ],\n",
       "        [ 502435.8 , -455782.94],\n",
       "        [ 489720.03, -449943.06]], dtype=float32),\n",
       " array([[[ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [-9.3581192e+01,  3.3630646e+02],\n",
       "         [-9.3484955e+01,  3.3630801e+02],\n",
       "         [-9.3394157e+01,  3.3630957e+02]],\n",
       " \n",
       "        [[ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [-9.3197411e+01,  3.3761377e+02],\n",
       "         [-9.3146835e+01,  3.3760815e+02],\n",
       "         [-9.3083626e+01,  3.3760419e+02]],\n",
       " \n",
       "        [[ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [-7.9999001e+01,  3.3617044e+02],\n",
       "         [-7.9871361e+01,  3.3616190e+02],\n",
       "         [-7.9754189e+01,  3.3613351e+02]],\n",
       " \n",
       "        [[ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [-7.4057755e+01,  3.3797708e+02],\n",
       "         [-7.4185341e+01,  3.3795081e+02],\n",
       "         [-7.4232376e+01,  3.3794504e+02]],\n",
       " \n",
       "        [[ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [-7.7247864e+01,  3.3767816e+02],\n",
       "         [-7.7371185e+01,  3.3767093e+02],\n",
       "         [-7.7486107e+01,  3.3767163e+02]],\n",
       " \n",
       "        [[ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [ 1.0000000e+09,  1.0000000e+09],\n",
       "         [-8.8115738e+01,  3.3492267e+02],\n",
       "         [-8.8229538e+01,  3.3487866e+02],\n",
       "         [-8.8347504e+01,  3.3485928e+02]]], dtype=float32))"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_info[55834]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  7737.2383,   2502.5723],\n",
       "       [  7737.2383,   2502.5723],\n",
       "       [  7737.2383,   2502.5723],\n",
       "       [  7737.2383,   2502.5723],\n",
       "       [  7737.2383,   2502.5723],\n",
       "       [  7737.2383,   2502.5723],\n",
       "       [  3750.7966,   1161.8721],\n",
       "       [     0.    ,      0.    ],\n",
       "       [ -8987.235 ,   6479.8306],\n",
       "       [-11258.331 ,  13533.628 ],\n",
       "       [-16099.9   ,  17379.883 ],\n",
       "       [-21041.93  ,  21272.998 ],\n",
       "       [-29415.035 ,  22716.262 ],\n",
       "       [-38167.81  ,  22235.52  ],\n",
       "       [-43827.26  ,  24197.838 ],\n",
       "       [-49535.203 ,  26212.621 ],\n",
       "       [-50825.492 ,  32747.904 ],\n",
       "       [-59157.35  ,  40610.52  ],\n",
       "       [-72712.66  ,  44406.145 ],\n",
       "       [-86182.92  ,  48193.92  ]], dtype=float32)"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results[55834]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([-40600.926, -120.61984, 16719.006, 55540.035], [55834, 87964, 56105, 55897])"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pos_range, pos_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-2.4663570e+02,  6.6555825e+03],\n",
       "       [-4.2536381e+01,  1.1740108e+04],\n",
       "       [ 1.1838752e+02,  1.7078488e+04],\n",
       "       [-6.0090210e+03,  2.4203889e+04],\n",
       "       [-5.2982822e+03,  3.0262061e+04],\n",
       "       [-1.6728844e+04,  3.5501480e+04],\n",
       "       [-2.7927158e+04,  4.0860746e+04],\n",
       "       [-3.3818582e+04,  4.7137102e+04],\n",
       "       [-3.9626605e+04,  5.3457750e+04],\n",
       "       [-2.1309232e+04,  6.0437129e+04],\n",
       "       [-1.7823463e+04,  6.5404715e+04],\n",
       "       [-2.5725193e+04,  7.2588164e+04]], dtype=float32)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "motion_modes[40]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "310033 125171 310168 125190 308938 114493"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ros_perception",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
